TF port of the Segment Anything Model (SAM) #22970
Conversation
The documentation is not available anymore as the PR was closed or merged.
```python
def flatten(input, start_dim=0, end_dim=-1):
    # Replicates the behavior of torch.flatten in TF

    # If end_dim or start_dim is negative, count them from the end
    if end_dim < 0:
        end_dim += input.shape.rank
    if start_dim < 0:
        start_dim += input.shape.rank

    if start_dim == end_dim:
        return input

    in_shape = tf.shape(input)
    flattened_dim = tf.math.reduce_prod(in_shape[start_dim : end_dim + 1])
    out_shape = tf.concat([in_shape[:start_dim], [flattened_dim], in_shape[end_dim + 1 :]], axis=0)
    return tf.reshape(input, out_shape)
```
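A quick usage check of the helper above (hypothetical example, not from the PR):

```python
import tensorflow as tf

t = tf.zeros((2, 3, 4, 5))
print(flatten(t).shape)        # (120,) - the default flattens every dim, like torch.flatten
print(flatten(t, 1, 2).shape)  # (2, 12, 5) - only dims 1 through 2 are merged
```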
🥲
I have no idea why I didn't do this before now!
```python
        return output_masks

    def post_process_masks_tf(
```
Have we started including separate post-processing ops in native TensorFlow? I thought they were NumPy only. This is indeed nice.
I wasn't sure about this - there's probably some code duplication in the processor I can remove.
Preprocessing is all in NumPy - this hasn't been extended to the postprocessing methods yet. Mainly because I haven't dared tackle torch.nn.functional.interpolate; partly because we haven't needed to yet.
That said - please don't have post_processing_xxx_tf! We don't use decode_tf for our tokenizers ;)
Could you rework the methods so there's a single post_process_xxx method and hidden framework-specific methods? i.e.

```python
def post_process_masks(self, masks, ...):
    if is_torch_tensor(masks):
        return self._post_process_masks_pt(...)
    if is_tf_tensor(masks):
        return self._post_process_masks_tf(...)
    ...
```
Sure! And sorry - I basically rushed through the processor code so I could get to the bit I was hyped about (benchmarking GPT-4's translations).
Force-pushed b9dd5a4 to b1f61bd
This is now almost ready to go and the code should be ready for review! Remaining issues:
sgugger left a comment
Why are there two different processing files, one of them not being imported anywhere?
The common tests should not be changed to have a higher tolerance - just override the right tests in the proper test file.
Also cc @amyeroberts since you reviewed the PyTorch model extensively.
Why have two functions that do the exact same thing?
Resolved as part of the general processor refactor!
Seems like this is leftover from debugging...
Resolved as part of the general processor refactor! (also oops, sorry)
What is the purpose of this file?
Shh, it's gone now. We don't talk about processing_tf_sam
Why have a separate test file to test the same class?
Also gone now!
Why change the tolerance for this model?
Adding a tolerance argument to the base tests triggered the test to run in other models, which caused this test to fail. I'll investigate and see if it's necessary, though!
Can we use more descriptive variable names?
This was copied straight from the PyTorch code, but on reflection I could probably refactor the whole thing out, because it was only there to deal with different memory orderings (whereas TensorFlow tensors are always contiguous and always have standard C memory ordering)
Done! I refactored the functional_layernorm function to handle alternate axes, and then just called that instead of this manual layernorm. Model output is unchanged and all integration tests still pass.
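For context, a functional layer norm over an arbitrary axis can be written roughly like this in TF (a minimal sketch only - the actual functional_layernorm in the PR may differ in name and signature):

```python
import tensorflow as tf

def functional_layernorm(inputs, weight, bias, epsilon=1e-6, axis=-1):
    # Normalize along the chosen axis, keeping dims so the stats broadcast
    mean = tf.math.reduce_mean(inputs, axis=axis, keepdims=True)
    variance = tf.math.reduce_variance(inputs, axis=axis, keepdims=True)
    normalized = (inputs - mean) * tf.math.rsqrt(variance + epsilon)
    # Reshape the affine parameters so they broadcast along that same axis
    shape = [1] * len(inputs.shape)
    shape[axis] = inputs.shape[axis]
    return normalized * tf.reshape(weight, shape) + tf.reshape(bias, shape)
```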
Replace this by?
Clarified that comment!
To address.
I never figured this out, but it's the same in Torch, and both models give equivalent outputs. @ArthurZucker do you know why this weight is non-trainable?
Couldn't find any reference to this random embedding in the paper (in fact, the paper always mentions learned positional embeddings), but the same pattern is in the SAM codebase
This meme is all I can think of
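For anyone following along, the pattern being discussed looks roughly like this (a sketch of the idea with assumed names, not the exact SAM code):

```python
import tensorflow as tf

class PositionEmbeddingRandom(tf.keras.layers.Layer):
    """Positional encodings from a fixed random Gaussian projection (sketch)."""

    def __init__(self, num_pos_feats=64, scale=1.0, **kwargs):
        super().__init__(**kwargs)
        self.num_pos_feats = num_pos_feats
        self.scale = scale

    def build(self, input_shape):
        # Initialized randomly once and never updated: the weight is saved in
        # the checkpoint, but trainable=False keeps gradients from touching it
        self.positional_embedding_matrix = self.add_weight(
            name="positional_embedding_matrix",
            shape=(2, self.num_pos_feats),
            initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.scale),
            trainable=False,
        )
        super().build(input_shape)
```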
Thanks for the review - about half of the comments relate to the processor code, which is definitely in need of a refactor, yes. Working on that now!
amyeroberts left a comment
Looking good!
Left some general comments - mainly w.r.t. the processing code. I'd like there to be as little TF/PT-specific code as possible. For postprocessing it's OK, as a lot of postprocessing is still PyTorch-specific, but preprocessing should be (as much as possible) framework-agnostic.
For the processor, can you add PT/TF cross-checks to make sure that the TF postprocessed outputs are equivalent to the PT ones?
```diff
 # overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
 # to generate masks during test
-def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict):
+def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict, tol=1e-5):
```
Do you need to add the tol argument here? Unless necessary, I'd avoid resetting the tol default in all the methods, so we only need to update it in one place.
I refactored this and reverted all the changes in the common tests
```python
if output_hidden_states:
    vision_hidden_states = vision_outputs[1]
if output_attentions:
    vision_attentions = vision_outputs[-1]
```
Could we instead pass in return_dict=True to self.vision_encoder and then explicitly access the values from the names? I'm not a big fan of accessing from indexes here
Done! (Also changed in the original PT code)
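The reworked access pattern looks roughly like this (a sketch; the output attribute names follow the usual transformers output classes and are assumptions here - note that sgugger points out further down that forcing return_dict=True breaks jit compilation, so that part was later reverted):

```python
vision_outputs = self.vision_encoder(
    pixel_values,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=True,
)
image_embeddings = vision_outputs.last_hidden_state
vision_hidden_states = vision_outputs.hidden_states  # None unless requested
vision_attentions = vision_outputs.attentions        # None unless requested
```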
```python
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
```
Why have these arguments?
I'm not clear about this one! Aren't these arguments common across most of our models?
I don't think so? Only SAM has get_image_embeddings, and all the other get_xxx_embeddings methods, as far as I can tell, just take self.
```python
# Matt: The original Torch code checked that the sum of sparse_prompt_embeddings equalled 0. However, this only
# happens when the sparse prompt embeddings are an empty tensor with shape[1] == 0. I replaced
# it with an explicit shape check to avoid data-dependent control flow which breaks XLA.
```
:)
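To spell out the difference (an illustrative sketch with assumed names, not the exact PR code):

```python
import tensorflow as tf

def pick_prompt_embeddings(sparse_prompt_embeddings, fallback_embeddings):
    # Data-dependent check (breaks XLA): branching on tensor *values*, which
    # are unknown while the graph is being traced.
    #   if tf.reduce_sum(sparse_prompt_embeddings) != 0: ...
    #
    # Shape-dependent check (XLA-friendly): shape[1] is static, so the branch
    # is resolved at trace time rather than at run time.
    if sparse_prompt_embeddings.shape[1] != 0:
        return sparse_prompt_embeddings
    return fallback_embeddings
```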
@amyeroberts @sgugger I refactored all the changes to the common tests, and just overrode …
gante left a comment
Looking good! 💪
Additional general comment: it seems like the Keras training argument is missing all around (in call and in the dropout layers)... but on the other hand, SAM is not trainable. Still, in case we add a training script, I'd add this quick future-proofing change :D
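A minimal sketch of the future-proofing being suggested (layer and parameter names are made up):

```python
import tensorflow as tf

class SketchBlock(tf.keras.layers.Layer):
    def __init__(self, hidden_size=256, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(hidden_size, name="dense")
        self.dropout = tf.keras.layers.Dropout(dropout_rate)

    def call(self, hidden_states, training=False):
        hidden_states = self.dense(hidden_states)
        # Dropout only fires when training=True; at inference it is a no-op
        return self.dropout(hidden_states, training=training)
```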
Force-pushed 76cebb9 to 17536e4
@gante I think all comments are now addressed, and I added … All comments from @amyeroberts and @sgugger should be addressed too - are you okay with going ahead and merging now once tests pass?
sgugger left a comment
Thanks for all the work on this. @amyeroberts could you also have a look before this is merged?
```diff
 points (`torch.Tensor`, **optional**):
     point coordinates and labels to embed.
-boxes (`torch.Tensor`, **optionnal**):
+boxes (`torch.Tensor`, **optional**):
     boxes to embed
-masks (`torch.Tensor`, **optionnal**):
+masks (`torch.Tensor`, **optional**):
```
Since we are touching this, can you put the optionals in italics and not bold?
Done!
```diff
-    return_dict=return_dict,
+    return_dict=True,
```
This cannot be forced as return_dict breaks jit compilation. This change needs reverting.
My bad - this was my suggestion, sorry @Rocketknight1!
Done!
```python
        values.
        """

    def __init__(self, config, downsample_rate=None, **kwargs) -> None:
```
The -> None makes zero sense to me as a type annotation (I know it's what PEP says, but the init returns an instance of the class). Since there are no type annotations elsewhere, maybe just remove it?
Done! (for all classes across both the PT and TF files)
amyeroberts left a comment
Nice! 🔥
Thanks for iterating, and in particular for spending the time to add equivalence tests for the processor and keep the image processing code tidy with the two frameworks 🤗
```python
        self.assertTrue(np.all(tf_masks[0].numpy() == pt_masks[0].numpy()))

    def test_image_processor_equivalence(self):
```
🤗
This tolerance seems pretty high 👀
It's actually okay - the values for the scores are very large (usually in the range 5-30). A tolerance of 2e-4 for numbers that big is quite tight!
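To make that concrete (illustrative numbers, not the actual test values):

```python
import numpy as np

# For scores around 20, an absolute tolerance of 2e-4 corresponds to a
# relative error of roughly 1e-5 - quite tight for float32 differences
# between frameworks.
pt_score, tf_score = 20.0002, 20.0001
assert np.allclose(pt_score, tf_score, atol=2e-4)
print(abs(pt_score - tf_score) / pt_score)  # ~5e-6 relative error
```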
Note in #23376 - input_boxes should be a list of lists of ints.
Fixed!
I know this is just copying from the PT implementation - but it would be great to add info to the docstring about what's returned, as there are many objects.
I'll be honest that I don't understand it too well, lol. I'll leave that for a follow-up on the Torch end and copy the strings whenever they do it 😅
layer norm layer here should take eps from config
layer norm layers here should take eps from config
The PyTorch version doesn't, and just uses the 1e-6 default kwarg value!
OK 👍
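For reference, the two options under discussion look like this (a sketch; the config attribute name is an assumption):

```python
import tensorflow as tf

class DummyConfig:  # stand-in for the model config (attribute name assumed)
    layer_norm_eps = 1e-12

config = DummyConfig()

# What the review suggested: read the epsilon from the model config
ln_from_config = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps)

# What both the PT and TF implementations actually do: the hard-coded default
ln_hardcoded = tf.keras.layers.LayerNormalization(epsilon=1e-6)
```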
Let's hope it's not too experimental 😬
tnp has been around since 2.4, I think we're safe!
ha! for TF I doubt it ;)
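For reference, the tf.experimental.numpy (tnp) module being discussed exposes NumPy-style ops on TF tensors; a trivial example:

```python
import tensorflow.experimental.numpy as tnp

x = tnp.arange(6)           # NumPy-style construction; returns a TF tensor
x = tnp.reshape(x, (2, 3))  # NumPy-style ops that also work inside tf.function
print(tnp.mean(x, axis=0))  # [1.5 2.5 3.5]
```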
Co-authored-by: Sylvain Gugger <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
Force-pushed 875cc35 to 3902969
I think comments are addressed now - are we okay to merge?
I'm treating silence as agreement, merging!
This is a first draft of the SAM port - will update this PR as I port tests and make sure everything is working okay. It's also a first proof-of-concept for full GPT-4 auto-translation from PyTorch: the entire modeling_tf_sam.py file was converted from PyTorch by GPT-4, with the exception of the imports at the top, because I haven't written a prompt for those yet.

Update: I checked over all of the code and fixed the issues in the GPT port. Equivalence tests all look good! This is almost ready to merge, but there are a few small issues left:

- channels_first doesn't actually work on CPU in TF